{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "eff56dd7-9567-49b8-9986-dadac44dd933",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Boundary extension and contraction\n",
    "\n",
    "This notebook contains code for exploring boundary extension and contraction in a VAE trained on the shapes3d dataset.\n",
    "\n",
    "Tested with tensorflow 2.11.0 and Python 3.10.9."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fc03a33-8116-4a38-8b01-6782a84a7654",
   "metadata": {},
   "source": [
    "#### Installation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "192b5b4a-55de-4177-93ab-9d0b85c9886f",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -r requirements.txt --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59ec08c9-c699-4bdc-8bcb-91ba4b9a5399",
   "metadata": {},
   "source": [
    "#### Imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d654d64-ae02-44f6-ad6f-35d61dff4375",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import zipfile\n",
    "import os\n",
    "import psutil\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import cv2\n",
    "from tensorflow import keras\n",
    "from PIL import Image, ImageOps\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "import tensorflow.keras.backend as K\n",
    "import matplotlib.pyplot as plt\n",
    "from tensorflow.keras import Model, Sequential, metrics, optimizers, layers\n",
    "from tensorflow.python.framework.ops import disable_eager_execution\n",
    "from utils import load_tfds_dataset\n",
    "from generative_model import encoder_network_large, decoder_network_large, VAE\n",
    "import tensorflow_datasets as tfds\n",
    "from sklearn.model_selection import train_test_split\n",
    "import numpy as np\n",
    "\n",
    "tf.keras.utils.set_random_seed(123)\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adb574f0-27e8-4d83-ade4-f8c5167fd22d",
   "metadata": {},
   "source": [
    "#### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5598c686-4c02-4d50-bdfa-796af8e1a86e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_tfds_dataset(dataset_name):\n",
    "    # Load the dataset\n",
    "    ds, ds_info = tfds.load(dataset_name, split='train', with_info=True, as_supervised=False)\n",
    "\n",
    "    # Convert to numpy arrays\n",
    "    images = []\n",
    "    label_scales = []\n",
    "    label_shapes = []\n",
    "    label_object_hues = []\n",
    "    label_floor_hues = []\n",
    "    label_wall_hues = []\n",
    "\n",
    "    counter=0\n",
    "\n",
    "    for item in tfds.as_numpy(ds):\n",
    "        images.append(item['image'])\n",
    "        label_scales.append(item['label_scale']),\n",
    "        label_shapes.append(item['label_shape'])\n",
    "        label_object_hues.append(item['label_object_hue'])\n",
    "        label_floor_hues.append(item['label_floor_hue'])\n",
    "        label_wall_hues.append(item['label_wall_hue'])\n",
    "        \n",
    "        counter += 1\n",
    "        if counter >= 40000:\n",
    "            break\n",
    "\n",
    "    \n",
    "    return np.array(images), np.array(label_scales), np.array(label_shapes), np.array(label_object_hues), np.array(label_floor_hues), np.array(label_wall_hues)\n",
    "\n",
    "# Load the dataset\n",
    "images, label_scales, label_shapes, label_object_hues, label_floor_hues, label_wall_hues = load_tfds_dataset('shapes3d')\n",
    "\n",
    "# Normalize the images\n",
    "images = images / 255.0\n",
    "\n",
    "# Split the dataset into training and testing sets\n",
    "train_images, test_images, train_label_scales, test_label_scales, train_label_shapes, test_label_shapes, train_label_object_hues, test_label_object_hues, train_label_floor_hues, test_label_floor_hues, train_label_wall_hues, test_label_wall_hues = train_test_split(\n",
    "    images, label_scales, label_shapes, label_object_hues, label_floor_hues, label_wall_hues, test_size=0.1, random_state=42\n",
    ")\n",
    "\n",
    "# Function to filter images based on shape and distinct colors\n",
    "def filter_images(images, label_scales, label_shapes, label_object_hues, label_floor_hues, label_wall_hues):\n",
    "    filtered_images = []\n",
    "    for img, scale, shape, obj_hue, floor_hue, wall_hue in zip(images, label_scales, label_shapes, label_object_hues, label_floor_hues, label_wall_hues):\n",
    "        if shape in [0] and obj_hue != floor_hue and obj_hue != wall_hue:\n",
    "            filtered_images.append(img)\n",
    "    return np.array(filtered_images)\n",
    "\n",
    "# Apply filters to training and testing datasets\n",
    "filtered_train_images = filter_images(train_images, train_label_scales, train_label_shapes, train_label_object_hues, train_label_floor_hues, train_label_wall_hues)\n",
    "filtered_test_images = filter_images(test_images, test_label_scales, test_label_shapes, test_label_object_hues, test_label_floor_hues, test_label_wall_hues)\n",
    "\n",
    "print(\"Filtered Training Images Shape:\", filtered_train_images.shape)\n",
    "print(\"Filtered Testing Images Shape:\", filtered_test_images.shape)\n",
    "\n",
    "test_ds = filtered_test_images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0a422cf-78bc-4865-ba84-691b5a7b2f7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.shuffle(test_ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e27f4ee3-3430-4a85-b730-64f3ddb19c08",
   "metadata": {},
   "source": [
    "#### Load the trained VAE:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8298c6e7-25bc-4c0b-8275-17cb4a21bf1f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "K.set_image_data_format('channels_last')\n",
    "\n",
    "latent_dim = 20\n",
    "input_shape = (64, 64, 3)\n",
    "\n",
    "encoder, z_mean, z_log_var = encoder_network_large(input_shape, latent_dim)\n",
    "decoder = decoder_network_large(latent_dim)\n",
    "\n",
    "encoder.load_weights(\"model_weights/shapes3d_encoder.h5\")\n",
    "decoder.load_weights(\"model_weights/shapes3d_decoder.h5\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5df3d67c-ca91-4fbd-a7bf-c5742363ca63",
   "metadata": {},
   "source": [
    "#### Test boundary extension / contraction:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f13691a-829c-4491-aa65-f8ce1c1e1024",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def remove_border(im_as_array, border_width=5):\n",
    "    img = Image.fromarray((im_as_array*255).astype(np.uint8))\n",
    "    im_crop = ImageOps.crop(img, border=border_width)\n",
    "    new_im = im_crop.resize((64,64))\n",
    "    return np.array(new_im) / 255\n",
    "\n",
    "def add_border(img, border_width=5):\n",
    "    img = np.pad(img*255, pad_width=((border_width,border_width),\n",
    "                                     (border_width,border_width),\n",
    "                                     (0,0)), mode='edge')\n",
    "    img = Image.fromarray(img.astype(np.uint8))\n",
    "    img = img.resize((64,64))\n",
    "    return np.array(img)/255\n",
    "\n",
    "def add_noise(array, noise_factor=0.1, seed=None, gaussian=False, replacement_val=0):\n",
    "    # Replace a fraction noise_factor of pixels with replacement_val or gaussian noise\n",
    "    if seed is not None:\n",
    "        np.random.seed(seed)\n",
    "    shape = array.shape\n",
    "    array = array.flatten()\n",
    "    indices = np.random.choice(np.arange(array.size), replace=False,\n",
    "                               size=int(array.size * noise_factor))\n",
    "    if gaussian is True:\n",
    "        array[indices] = np.random.normal(loc=0.5, scale=1.0, size=array[indices].shape)\n",
    "    else:\n",
    "        array[indices] = replacement_val\n",
    "    array = array.reshape(shape)\n",
    "    return np.clip(array, 0.0, 1.0)\n",
    "\n",
    "def display_recalled(x_test_new, decoded_imgs, n=10):\n",
    "    plt.figure(figsize=(n*2, 4))\n",
    "    for i in range(n):\n",
    "        # display original\n",
    "        ax = plt.subplot(2, n, i + 1)\n",
    "        plt.imshow(x_test_new[i])\n",
    "        plt.gray()\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "\n",
    "        # display reconstruction\n",
    "        ax = plt.subplot(2, n, i + n + 1)\n",
    "        plt.imshow(decoded_imgs[i])\n",
    "        plt.gray()\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "    plt.show()\n",
    "\n",
    "code = Model(encoder.input, encoder.get_layer('mean').output)\n",
    "\n",
    "x_test_new = np.array([add_noise(image) for image in test_ds[0:20]])\n",
    "encoded_imgs = code.predict(x_test_new)\n",
    "decoded_imgs = decoder.predict(encoded_imgs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72f2a8a6",
   "metadata": {},
   "source": [
    "#### Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b19cf9d3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "border_const = 5\n",
    "\n",
    "def plot_zoom_rows(ind):\n",
    "    x_test_new_remove = np.array([add_noise(remove_border(test_ds[ind], border_width=5.33*i)) for i in range(2)])\n",
    "    encoded_imgs = code.predict(x_test_new_remove)\n",
    "    decoded_imgs_remove_border = decoder.predict(encoded_imgs)\n",
    "\n",
    "    x_test_new_add = np.array([add_noise(add_border(test_ds[ind], border_width=8*i)) for i in range(2)])\n",
    "    encoded_imgs = code.predict(x_test_new_add)\n",
    "    decoded_imgs_add_border = decoder.predict(encoded_imgs)\n",
    "\n",
    "    display_recalled(x_test_new_add[::-1].tolist() + x_test_new_remove.tolist()[1:], \n",
    "                     decoded_imgs_add_border[::-1].tolist() + decoded_imgs_remove_border.tolist()[1:], n=3)\n",
    "\n",
    "for i in range (0,10):\n",
    "    plot_zoom_rows(i)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e294550-5b25-4ddf-b3fd-3150f16c9470",
   "metadata": {},
   "source": [
    "#### Measure change in object size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62cc2a13-1c00-4242-bced-208b6190193b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def segment_image(image, k=5):\n",
    "    # reshape the image to be a list of RGB pixels and convert to float32\n",
    "    pixels = image.reshape(-1, 3).astype(np.float32)\n",
    "    \n",
    "    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)\n",
    "    _, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)\n",
    "    \n",
    "    # reshape to original image shape\n",
    "    segmented_image = centers[labels.flatten()].reshape(image.shape).astype(np.uint8)\n",
    "\n",
    "    return segmented_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a35f665e-0174-4cfc-ad69-624178e49233",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def measure_object_height(image):\n",
    "    # Define the area of interest for detecting the object's color\n",
    "    mid_col_idx = image.shape[1] // 2\n",
    "    start_row_idx = int(image.shape[0] * 0.5)\n",
    "    end_row_idx = int(image.shape[0] * 0.8)\n",
    "\n",
    "    # Extract the middle column in the area of interest\n",
    "    detection_col = image[start_row_idx:end_row_idx, mid_col_idx]\n",
    "\n",
    "    # Determine the color of the central object\n",
    "    flat_detection_col = detection_col.reshape(-1, 3)\n",
    "    colors, counts = np.unique(flat_detection_col, axis=0, return_counts=True)\n",
    "    object_color = colors[counts.argmax()]\n",
    "\n",
    "    # Define a threshold for color similarity\n",
    "    color_threshold = 0.5\n",
    "\n",
    "    # Extract the entire middle column\n",
    "    full_mid_col = image[:, mid_col_idx]\n",
    "\n",
    "    # Calculate the color difference for each pixel in the full middle column\n",
    "    color_diff = np.sqrt(np.sum((full_mid_col - object_color) ** 2, axis=1))\n",
    "    is_object_color = color_diff < color_threshold\n",
    "\n",
    "    # Count the pixels of the object color in the full middle column\n",
    "    object_height = np.sum(is_object_color)\n",
    "\n",
    "    return object_height\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60d4d567-19d6-4112-a8a2-13b3143c2685",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # visualise some examples to check the functions above\n",
    "\n",
    "# zoom_levels = [80, 100, 120]\n",
    "# for zoom in zoom_levels:\n",
    "#     image = test_ds[4]\n",
    "#     print(zoom)\n",
    "#     border_width = abs(int((32 / (zoom / 100)) - 32))\n",
    "#     if zoom < 100:\n",
    "#         image = add_border(image, border_width=border_width)\n",
    "#         print(f\"Adding border of {border_width}\")\n",
    "#     elif zoom > 100:\n",
    "#         image = remove_border(image, border_width=border_width)\n",
    "#         print(f\"Removing border of {border_width}\")\n",
    "    \n",
    "#     # Add noise and use autoencoder\n",
    "#     encoded_img = code.predict(np.array([add_noise(image, noise_factor=0.0)]))\n",
    "#     decoded_img = decoder.predict(encoded_img)[0]\n",
    "    \n",
    "#     # Segment images\n",
    "#     segmented_input = segment_image(image * 255)\n",
    "#     segmented_output = segment_image(decoded_img * 255)\n",
    "\n",
    "#     # Measure object height\n",
    "#     input_height = measure_object_height(segmented_input)\n",
    "#     output_height = measure_object_height(segmented_output)\n",
    "#     print(input_height, output_height)\n",
    "\n",
    "#     # Plotting the images\n",
    "#     fig, axs = plt.subplots(2, 2, figsize=(3, 3))\n",
    "#     axs[0, 0].imshow(image)\n",
    "#     axs[0, 0].set_title('Input Image')\n",
    "#     axs[0, 1].imshow(segmented_input)\n",
    "#     axs[0, 1].set_title('Segmented In')\n",
    "#     axs[1, 0].imshow(decoded_img)\n",
    "#     axs[1, 0].set_title('Output Image')\n",
    "#     axs[1, 1].imshow(segmented_output)\n",
    "#     axs[1, 1].set_title('Segmented Out')\n",
    "    \n",
    "#     for ax in axs.flat:\n",
    "#         ax.axis('off')\n",
    "\n",
    "#     plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae1ea31f-f5e6-4f14-9c64-b506ddf78f36",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "zoom_levels = range(80, 121, 5)\n",
    "size_changes_dict = {zoom: [] for zoom in zoom_levels}\n",
    "\n",
    "def get_size_change_and_zoom(ind):\n",
    "    changes = []\n",
    "\n",
    "    for zoom in zoom_levels:\n",
    "        # margin = (32 / ratio) - 32\n",
    "        # E.g. the margin to add for a zoom percentage of 80% (i.e. ratio of 0.8) is 8 pixels\n",
    "        border_width = abs(int((32 / (zoom / 100)) - 32))\n",
    "        image = test_ds[ind]\n",
    "        if zoom < 100:\n",
    "            image = add_border(image, border_width=border_width)\n",
    "        elif zoom > 100:\n",
    "            image = remove_border(image, border_width=border_width)\n",
    "        \n",
    "        encoded_img = code.predict(np.array([add_noise(image, noise_factor=0.1)]))\n",
    "        decoded_img = decoder.predict(encoded_img)[0]\n",
    "\n",
    "        input_height = measure_object_height(segment_image(image*255))\n",
    "        output_height = measure_object_height(segment_image(decoded_img*255))\n",
    "        print(input_height, output_height)\n",
    "\n",
    "        if input_height != 0:\n",
    "            change = (output_height - input_height) / input_height\n",
    "        else:\n",
    "            change = 0\n",
    "        changes.append(change)\n",
    "        \n",
    "    return changes, zoom_levels\n",
    "\n",
    "size_changes = []\n",
    "zoom_changes = []\n",
    "\n",
    "for i in range(500):\n",
    "    changes, zoom_levels = get_size_change_and_zoom(i)\n",
    "    size_changes.extend(changes)\n",
    "    zoom_changes.extend(zoom_levels)\n",
    "\n",
    "# Separate size changes by zoom level\n",
    "for size_change, zoom_level in zip(size_changes, zoom_changes):\n",
    "    size_changes_dict[zoom_level].append(size_change)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d97a3df2-ef89-46a7-8c7f-c3e24be114be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate means and standard deviations\n",
    "means = [np.mean(size_changes_dict[zoom]) for zoom in zoom_levels]\n",
    "std_devs = [np.std(size_changes_dict[zoom]) for zoom in zoom_levels]\n",
    "sems = [np.std(size_changes_dict[zoom]) / np.sqrt(len(size_changes_dict[zoom])) for zoom in zoom_levels]\n",
    "\n",
    "# Create bar chart with error bars\n",
    "plt.figure(figsize=(8,6))\n",
    "plt.bar(zoom_levels, means, yerr=sems, capsize=5, width=3.5)\n",
    "plt.xticks(fontsize=18)\n",
    "plt.yticks(fontsize=18)\n",
    "plt.xlabel('Zoom Level (%)', fontsize=18)\n",
    "plt.ylabel('Change in object size', fontsize=18)\n",
    "plt.axhline(y=0, color='black', linewidth=0.8) \n",
    "ax=plt.gca()\n",
    "ax.invert_xaxis()\n",
    "plt.savefig('BE.png')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7582dd4a-2220-40bb-98fd-f0d6641efc66",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = [size_changes_dict[zoom] for zoom in zoom_levels][0:200]\n",
    "# Create figure and axis\n",
    "plt.figure(figsize=(8, 6))\n",
    "\n",
    "positions = np.arange(len(data))\n",
    "width = 0.9 \n",
    "\n",
    "plt.boxplot(data, positions=positions, widths=width, patch_artist=True, \n",
    "            boxprops=dict(facecolor='blue', alpha=0.5), medianprops=dict(color='black'),\n",
    "           showfliers=False, whis=[10,90])\n",
    "\n",
    "plt.xticks(positions, zoom_levels, fontsize=18)\n",
    "plt.yticks(fontsize=18)\n",
    "plt.xlabel('Zoom Level (%)', fontsize=18)\n",
    "plt.ylabel('Change in object size', fontsize=18)\n",
    "plt.axhline(y=0, color='black', linewidth=0.8)\n",
    "\n",
    "ax = plt.gca()\n",
    "ax.invert_xaxis()\n",
    "\n",
    "plt.savefig('BE_boxplot.pdf')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
